package utils

import (
	"crypto"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"crypto/x509"
	"encoding/base64"
	"encoding/pem"
	"fmt"
	"strings"
)

// RsaKeys holds an RSA public and/or private key together with the PEM text
// each was parsed from. NewRsaKeys fills it in; when NewRsaKeys is given an
// empty string for one of the keys, the matching string field is "" and the
// matching key pointer is left nil.
type RsaKeys struct {
	PublicKeyStr  string          // PEM text passed to NewRsaKeys as publicKey
	PrivateKeyStr string          // PEM text passed to NewRsaKeys as privateKey
	PublicKey     *rsa.PublicKey  // parsed public key; nil if none was supplied
	PrivateKey    *rsa.PrivateKey // parsed private key; nil if none was supplied
}

// rsaWrongKey is the sentinel error NewRsaKeys returns when a non-empty key
// string does not contain a decodable PEM block.
var (
	rsaWrongKey = fmt.Errorf("wrong key")
)

// RsaGenKeys generates a new RSA key pair whose modulus is bits long (for
// example 2048; 1024-bit RSA is considered too weak for new keys) and returns
// both halves PEM-encoded in PKCS#1 form: publicKey is an "RSA PUBLIC KEY"
// block and privateKey is an "RSA PRIVATE KEY" block.
//
// If key generation fails, both byte slices are nil and err describes the
// failure.
func RsaGenKeys(bits int) (publicKey, privateKey []byte, err error) {
	key, genErr := rsa.GenerateKey(rand.Reader, bits)
	if genErr != nil {
		return nil, nil, genErr
	}

	privateKey = pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	})
	publicKey = pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PUBLIC KEY",
		Bytes: x509.MarshalPKCS1PublicKey(&key.PublicKey),
	})
	return publicKey, privateKey, nil
}

// NewRsaKeys parses PEM-encoded RSA keys into an RsaKeys value. Either
// argument may be empty, in which case the corresponding key field is left
// nil (for example, a public-only value for encrypting/verifying).
//
// publicKey may contain a PKCS#1 ("RSA PUBLIC KEY") or PKIX/SubjectPublicKeyInfo
// ("PUBLIC KEY") block. privateKey may contain a PKCS#1 ("RSA PRIVATE KEY") or
// PKCS#8 ("PRIVATE KEY") block. Only the first PEM block of each string is
// used. Keys produced by, e.g.:
//
//	openssl genrsa -out rsa_private_key.pem 2048
//	openssl rsa -in rsa_private_key.pem -pubout -out rsa_public_key.pem
//
// are accepted.
//
// Errors: rsaWrongKey itself when a non-empty input contains no PEM block;
// an error wrapping rsaWrongKey when a block holds a non-RSA key (ECDSA,
// Ed25519, ...); otherwise the x509 parse error, wrapped with context.
func NewRsaKeys(publicKey, privateKey string) (*RsaKeys, error) {
	rk := &RsaKeys{
		PublicKeyStr:  publicKey,
		PrivateKeyStr: privateKey,
	}
	if publicKey != "" {
		// Any data after the first PEM block is ignored.
		block, _ := pem.Decode([]byte(publicKey))
		if block == nil {
			return nil, rsaWrongKey
		}
		pub, err := parseRsaPublicKeyBlock(block)
		if err != nil {
			return nil, fmt.Errorf("parsing public key: %w", err)
		}
		rk.PublicKey = pub
	}
	if privateKey != "" {
		block, _ := pem.Decode([]byte(privateKey))
		if block == nil {
			return nil, rsaWrongKey
		}
		priv, err := parseRsaPrivateKeyBlock(block)
		if err != nil {
			return nil, fmt.Errorf("parsing private key: %w", err)
		}
		rk.PrivateKey = priv
	}
	return rk, nil
}

// parseRsaPublicKeyBlock decodes a PEM public-key payload as PKCS#1 or PKIX.
// The PEM label only chooses which encoding is tried first ("RSA ..." labels
// are conventionally PKCS#1, "PUBLIC KEY" is PKIX). The other encoding is
// always tried as a fallback because labels are not always accurate. Errors
// are never inspected by message text, since that text is not a stable API.
func parseRsaPublicKeyBlock(block *pem.Block) (*rsa.PublicKey, error) {
	parsePKIX := func(der []byte) (*rsa.PublicKey, error) {
		key, err := x509.ParsePKIXPublicKey(der)
		if err != nil {
			return nil, err
		}
		// Checked assertion: a valid non-RSA key must yield an error, not a panic.
		pub, ok := key.(*rsa.PublicKey)
		if !ok {
			return nil, fmt.Errorf("public key is %T, not RSA: %w", key, rsaWrongKey)
		}
		return pub, nil
	}

	primary, fallback := x509.ParsePKCS1PublicKey, parsePKIX
	if !strings.HasPrefix(block.Type, "RSA ") {
		primary, fallback = parsePKIX, x509.ParsePKCS1PublicKey
	}

	pub, err := primary(block.Bytes)
	if err == nil {
		return pub, nil
	}
	pub, fallbackErr := fallback(block.Bytes)
	if fallbackErr == nil {
		return pub, nil
	}
	return nil, fmt.Errorf("%w (alternate encoding: %v)", err, fallbackErr)
}

// parseRsaPrivateKeyBlock decodes a PEM private-key payload as PKCS#1 or
// PKCS#8. The label-driven ordering and fallback mirror parseRsaPublicKeyBlock.
func parseRsaPrivateKeyBlock(block *pem.Block) (*rsa.PrivateKey, error) {
	parsePKCS8 := func(der []byte) (*rsa.PrivateKey, error) {
		key, err := x509.ParsePKCS8PrivateKey(der)
		if err != nil {
			return nil, err
		}
		// Checked assertion: a valid non-RSA key must yield an error, not a panic.
		priv, ok := key.(*rsa.PrivateKey)
		if !ok {
			return nil, fmt.Errorf("private key is %T, not RSA: %w", key, rsaWrongKey)
		}
		return priv, nil
	}

	primary, fallback := x509.ParsePKCS1PrivateKey, parsePKCS8
	if !strings.HasPrefix(block.Type, "RSA ") {
		primary, fallback = parsePKCS8, x509.ParsePKCS1PrivateKey
	}

	priv, err := primary(block.Bytes)
	if err == nil {
		return priv, nil
	}
	priv, fallbackErr := fallback(block.Bytes)
	if fallbackErr == nil {
		return priv, nil
	}
	return nil, fmt.Errorf("%w (alternate encoding: %v)", err, fallbackErr)
}

// EncodeWithPublicKey encrypts txt with rk.PublicKey using RSA PKCS#1 v1.5
// padding and returns the ciphertext as standard base64. rsa.EncryptPKCS1v15
// rejects messages longer than the modulus size minus 11 bytes.
//
// It returns an error instead of panicking when no public key is loaded
// (NewRsaKeys leaves PublicKey nil when given an empty public key string).
//
// NOTE(review): PKCS#1 v1.5 encryption is a legacy scheme and OAEP is
// preferred for new protocols. Switching would break existing peers, so it
// is kept.
func (rk *RsaKeys) EncodeWithPublicKey(txt string) (string, error) {
	if rk == nil || rk.PublicKey == nil {
		return "", fmt.Errorf("rsa encrypt: no public key loaded")
	}
	ciphertext, err := rsa.EncryptPKCS1v15(rand.Reader, rk.PublicKey, []byte(txt))
	if err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(ciphertext), nil
}

// DecodeWithPrivateKey reverses EncodeWithPublicKey. It decodes txt as
// standard base64, decrypts the bytes with rk.PrivateKey using RSA PKCS#1
// v1.5 padding, and returns the plaintext as a string.
//
// It returns an error instead of panicking when no private key is loaded
// (NewRsaKeys leaves PrivateKey nil when given an empty private key string).
func (rk *RsaKeys) DecodeWithPrivateKey(txt string) (string, error) {
	if rk == nil || rk.PrivateKey == nil {
		return "", fmt.Errorf("rsa decrypt: no private key loaded")
	}
	ciphertext, err := base64.StdEncoding.DecodeString(txt)
	if err != nil {
		return "", err
	}
	plaintext, err := rsa.DecryptPKCS1v15(rand.Reader, rk.PrivateKey, ciphertext)
	if err != nil {
		return "", err
	}
	return string(plaintext), nil
}

// SignWithPrivateKey signs the SHA-256 digest of txt with rk.PrivateKey using
// RSA PKCS#1 v1.5 and returns the signature as standard base64. The hashing
// matches VerifyWithPublicKey, so that method accepts the result for the same
// txt.
//
// It returns an error instead of panicking when no private key is loaded
// (NewRsaKeys leaves PrivateKey nil when given an empty private key string).
func (rk *RsaKeys) SignWithPrivateKey(txt string) (string, error) {
	if rk == nil || rk.PrivateKey == nil {
		return "", fmt.Errorf("rsa sign: no private key loaded")
	}
	digest := sha256.Sum256([]byte(txt))
	signature, err := rsa.SignPKCS1v15(rand.Reader, rk.PrivateKey, crypto.SHA256, digest[:])
	if err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(signature), nil
}

// VerifyWithPublicKey checks that sign, a standard-base64 signature such as
// SignWithPrivateKey produces, is a valid RSA PKCS#1 v1.5 signature by
// rk.PublicKey over the SHA-256 digest of txt.
//
// It returns nil if the signature verifies. Otherwise it returns a non-nil
// error: invalid base64, a signature mismatch, or no public key loaded.
// NewRsaKeys leaves PublicKey nil when given an empty public key string, and
// that case is reported as an error rather than a panic.
func (rk *RsaKeys) VerifyWithPublicKey(sign, txt string) error {
	if rk == nil || rk.PublicKey == nil {
		return fmt.Errorf("rsa verify: no public key loaded")
	}
	signed, err := base64.StdEncoding.DecodeString(sign)
	if err != nil {
		return err
	}
	hashed := sha256.Sum256([]byte(txt))
	return rsa.VerifyPKCS1v15(rk.PublicKey, crypto.SHA256, hashed[:], signed)
}
